# coding=utf8
from __future__ import division
import os
import torch
import torch.utils.data as data
import numpy as np
import pandas as pd
import cv2
from collections import OrderedDict
from itertools import chain
# Disable OpenCV's internal threading; presumably to avoid thread
# oversubscription / hangs when images are decoded inside DataLoader worker
# processes -- TODO confirm.
cv2.setNumThreads(0)

# NOTE(review): this pins the process to GPU 0 as a side effect of merely
# importing this module; presumably it only has an effect if it runs before
# CUDA is initialised -- confirm callers rely on it.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Attribute key (the AttrKey column of the annotation CSV) -> attribute index.
attr2idx_map = {
    'coat_length_labels': 0,
    'collar_design_labels': 1,
    'lapel_design_labels': 2,
    'neck_design_labels': 3,
    'neckline_design_labels': 4,
    'pant_length_labels': 5,
    'skirt_length_labels': 6,
    'sleeve_length_labels': 7,
}

# Inverse of attr2idx_map, ordered by attribute index. Built from a sorted
# sequence of (idx, key) pairs rather than from a dict literal: a dict literal
# has no guaranteed iteration order before Python 3.7, so OrderedDict({...})
# would not reliably preserve the index order.
idx2attr_map = OrderedDict(sorted((idx, key) for key, idx in attr2idx_map.items()))

# Attribute index -> number of label values (length of its AttrValues string).
attr2length_map = {0: 8,
                   1: 5,
                   2: 5,
                   3: 5,
                   4: 10,
                   5: 6,
                   6: 6,
                   7: 9}


def _cumulative_spans(lengths):
    """Return ``{idx: [start, end]}`` placing each idx's span back to back in sorted-index order."""
    spans = {}
    start = 0
    for idx in sorted(lengths):
        end = start + lengths[idx]
        spans[idx] = [start, end]
        start = end
    return spans


# Attribute index -> [start, end) slice of that attribute inside the
# concatenation of all eight attributes' label vectors. Derived from
# attr2length_map so the two tables cannot drift apart.
attr2catidx_map = _cumulative_spans(attr2length_map)

# Per-character label code -> score ('n' -> 0, 'm' -> 0.5, 'y' -> 1;
# presumably no / maybe / yes).
label_map = {'n': 0., 'm': 0.5, 'y': 1.}


class FSdata(data.Dataset):
    """Dataset of (image, attribute) samples whose labels are laid out on a
    concatenated multi-attribute label vector.

    Each annotation row describes one image and one attribute. The label
    string (one character per label value) is mapped character-wise through
    ``label_map``, normalised to sum to 1, and written into that attribute's
    slice of a vector that covers every attribute in ``select``.

    Args:
        anno_pd: DataFrame with ``ImageName``, ``AttrKey`` and ``AttrValues``
            columns. ``ImageName`` is used as the image path as-is.
        transforms: optional callable ``(img, attr_idx) -> (img, attr_idx)``.
        label_map: per-character label code -> score.
        select: attribute indices included in the concatenated label vector,
            in layout order.
    """

    def __init__(self, anno_pd, transforms=None, label_map=label_map, select=range(0,8)):
        self.paths = anno_pd['ImageName'].tolist()
        self.attrs = anno_pd['AttrKey'].tolist()
        self.labels = anno_pd['AttrValues'].tolist()
        self.transforms = transforms
        self.label_map = label_map
        self.select = select
        # Place each selected attribute's label slot back to back:
        # catidx_map[attr_idx] = [start, end) within the concatenated vector.
        self.catlen = 0
        self.catidx_map = {}
        for attr_idx in select:
            self.catidx_map[attr_idx] = [self.catlen, self.catlen + attr2length_map[attr_idx]]
            self.catlen += attr2length_map[attr_idx]

        self.num_classes = [attr2length_map[x] for x in select]

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _load_rgb(img_path):
        """Read an image from disk and convert it from OpenCV's BGR order to RGB."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None, which would
            # otherwise surface as an opaque error inside cvtColor.
            raise IOError('could not read image: %s' % img_path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # [h,w,3]  RGB

    def _encode_label(self, attr_idx, label_string):
        """Spread one attribute's label string over the concatenated layout.

        Returns:
            (attr_mask, cat_label, cat_label_str), each of length
            ``self.catlen``: an int32 mask that is 1 on the attribute's slice,
            the normalised scores on that slice (0 elsewhere), and the label
            characters on that slice ('n' elsewhere).
        """
        start, end = self.catidx_map[attr_idx]

        label = np.array([self.label_map[s] for s in label_string])
        total = label.sum()
        if total != 0:
            # Normalise so the attribute's scores sum to 1.
            label = label / total
        # Otherwise every score is 0 (e.g. an all-'n' string); dividing would
        # yield NaN targets, so the zeros are kept as-is.

        attr_mask = np.zeros(self.catlen, dtype=np.int32)
        attr_mask[start:end] = 1

        cat_label = np.zeros(self.catlen, dtype=np.float32)
        cat_label[start:end] = label

        cat_label_str = ['n'] * self.catlen
        cat_label_str[start:end] = list(label_string)

        return attr_mask, cat_label, cat_label_str

    def __getitem__(self, item):
        """Return ``(img, attr_idx, attr_mask, cat_label, cat_label_str)``.

        ``img`` is a float tensor of the (optionally transformed) RGB image;
        ``attr_idx`` is the attribute index as returned by ``transforms``;
        ``attr_mask`` / ``cat_label`` are float tensors of length
        ``self.catlen``; ``cat_label_str`` is a list of label characters.
        """
        img = self._load_rgb(self.paths[item])

        attr_idx = attr2idx_map[self.attrs[item]]
        # Labels are encoded with the attribute index *before* transforms run.
        attr_mask, cat_label, cat_label_str = self._encode_label(attr_idx, self.labels[item])

        if self.transforms is not None:
            img, attr_idx = self.transforms(img, attr_idx)

        return torch.from_numpy(img).float(), \
               attr_idx, \
               torch.from_numpy(attr_mask).float(), \
               torch.from_numpy(cat_label).float(), \
               cat_label_str



class FSdata_aux(data.Dataset):
    """Dataset of (image, attribute) samples with a concatenated label vector
    plus an auxiliary per-attribute target.

    Images are read from ``os.path.join(root_path, ImageName)``. The label
    string is mapped character-wise through ``label_map``, normalised to sum
    to 1, and written into the attribute's slice of a vector covering every
    attribute in ``select``. The auxiliary target has one slot per selected
    attribute (in ``select`` order) and is 1 at the sample's attribute unless
    its label string is listed in ``auxignore`` for that attribute.

    Args:
        root_path: directory that ``ImageName`` entries are joined onto.
        anno_pd: DataFrame with ``ImageName``, ``AttrKey`` and ``AttrValues``
            columns.
        transforms: optional callable ``(img, attr_idx) -> (img, attr_idx)``.
        label_map: per-character label code -> score.
        select: attribute indices included in the concatenated label vector,
            in layout order.
        auxignore: optional dict ``{attr_idx: [label_string, ...]}``. Samples
            whose label string is listed for their attribute get an all-zero
            auxiliary target. ``None`` (the default) ignores nothing.
    """

    def __init__(self, root_path, anno_pd, transforms=None, label_map=label_map, select=range(0,8),
                 auxignore=None):
        self.root_path = root_path
        self.paths = anno_pd['ImageName'].tolist()
        self.attrs = anno_pd['AttrKey'].tolist()
        self.labels = anno_pd['AttrValues'].tolist()
        self.transforms = transforms
        self.label_map = label_map
        self.select = select
        # catidx_map[attr_idx] = [start, end) within the concatenated vector;
        # auxlabel_map[attr_idx] = position of attr_idx within `select`.
        self.catlen = 0
        self.catidx_map = {}
        self.auxlabel_map = {}
        for i, attr_idx in enumerate(select):
            self.auxlabel_map[attr_idx] = i
            self.catidx_map[attr_idx] = [self.catlen, self.catlen + attr2length_map[attr_idx]]
            self.catlen += attr2length_map[attr_idx]

        self.num_classes = [attr2length_map[x] for x in select]
        self.auxignore = auxignore

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _load_rgb(img_path):
        """Read an image from disk and convert it from OpenCV's BGR order to RGB."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None, which would
            # otherwise surface as an opaque error inside cvtColor.
            raise IOError('could not read image: %s' % img_path)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # [h,w,3]  RGB

    def _encode_label(self, attr_idx, label_string):
        """Spread one attribute's label string over the concatenated layout.

        Returns:
            (attr_mask, cat_label, cat_label_str), each of length
            ``self.catlen``: an int32 mask that is 1 on the attribute's slice,
            the normalised scores on that slice (0 elsewhere), and the label
            characters on that slice ('n' elsewhere).
        """
        start, end = self.catidx_map[attr_idx]

        label = np.array([self.label_map[s] for s in label_string])
        total = label.sum()
        if total != 0:
            # Normalise so the attribute's scores sum to 1.
            label = label / total
        # Otherwise every score is 0 (e.g. an all-'n' string); dividing would
        # yield NaN targets, so the zeros are kept as-is.

        attr_mask = np.zeros(self.catlen, dtype=np.int32)
        attr_mask[start:end] = 1

        cat_label = np.zeros(self.catlen, dtype=np.float32)
        cat_label[start:end] = label

        cat_label_str = ['n'] * self.catlen
        cat_label_str[start:end] = list(label_string)

        return attr_mask, cat_label, cat_label_str

    def __getitem__(self, item):
        """Return ``(img, attr_idx, attr_mask, cat_label, cat_label_str, auxlabel)``.

        The first five fields are as for the plain dataset; ``auxlabel`` is a
        float tensor with one slot per selected attribute.
        """
        img = self._load_rgb(os.path.join(self.root_path, self.paths[item]))

        attr_idx = attr2idx_map[self.attrs[item]]
        # Labels are encoded with the attribute index *before* transforms run.
        attr_mask, cat_label, cat_label_str = self._encode_label(attr_idx, self.labels[item])

        if self.transforms is not None:
            img, attr_idx = self.transforms(img, attr_idx)

        auxlabel = np.zeros(len(self.auxlabel_map), dtype=int)
        # auxignore=None (the constructor default) must mean "ignore nothing";
        # previously it crashed on `None.keys()`.
        ignored = (self.auxignore or {}).get(attr_idx, ())
        if self.labels[item] not in ignored:
            auxlabel[self.auxlabel_map[attr_idx]] = 1

        return torch.from_numpy(img).float(), \
               attr_idx, \
               torch.from_numpy(attr_mask).float(), \
               torch.from_numpy(cat_label).float(), \
               cat_label_str, \
               torch.from_numpy(auxlabel).float()


def collate_fn(batch):
    """Merge a list of 5-field samples into one batch.

    Images, attribute masks and concatenated labels are stacked along a new
    leading dimension; attribute indices and per-sample label-character lists
    are passed through as plain Python lists in batch order.
    """
    images = [sample[0] for sample in batch]
    attr_indices = [sample[1] for sample in batch]
    masks = [sample[2] for sample in batch]
    labels = [sample[3] for sample in batch]
    label_strs = [sample[4] for sample in batch]

    return (torch.stack(images, 0),
            attr_indices,
            torch.stack(masks, 0),
            torch.stack(labels, 0),
            label_strs)


def collate_fn_aux(batch):
    """Merge a list of 6-field samples (with auxiliary target) into one batch.

    Images, attribute masks, concatenated labels and auxiliary targets are
    stacked along a new leading dimension; attribute indices and per-sample
    label-character lists are passed through as plain Python lists.
    """
    images = [sample[0] for sample in batch]
    attr_indices = [sample[1] for sample in batch]
    masks = [sample[2] for sample in batch]
    labels = [sample[3] for sample in batch]
    label_strs = [sample[4] for sample in batch]
    aux_targets = [sample[5] for sample in batch]

    return (torch.stack(images, 0),
            attr_indices,
            torch.stack(masks, 0),
            torch.stack(labels, 0),
            label_strs,
            torch.stack(aux_targets))

if __name__ == '__main__':
    # Smoke test: build the auxiliary dataset for a subset of attributes and
    # print the first collated batch.
    from FSdata.FSaug import Compose, Resize, RandomUpperCrop, RandomHflip, Normalize

    class FSAug(object):
        """Training augmentation pipeline called as ``aug(image, attr_idx)``.

        Resize is configured for attributes [0, 1, 2, 5, 6, 7] and
        RandomUpperCrop for [3, 4] (presumably each op only acts on its
        ``select`` list -- see FSaug), followed by a random horizontal flip
        and mean/std normalisation.
        """

        def __init__(self):
            self.augment = Compose([
                Resize(size=(336, 336), select=[0, 1, 2, 5, 6, 7]),
                RandomUpperCrop(size=(336, 336), select=[3, 4]),
                RandomHflip(),
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

        def __call__(self, image, attr_idx):
            return self.augment(image, attr_idx)

    rawdata_root = '/media/gserver/data/FashionAI'
    all_pd = pd.read_csv(os.path.join(rawdata_root, 'base/Annotations/label.csv'),
                         header=None, names=['ImageName', 'AttrKey', 'AttrValues'])
    # Image paths in the CSV are relative to the 'base' sub-directory.
    all_pd['ImageName'] = all_pd['ImageName'].apply(lambda x: os.path.join('base', x))

    # Keep only rows for the selected attributes; the same selection also
    # defines the dataset's label layout below.
    select_AttrIdx = [3, 4]
    select_AttrKey = [idx2attr_map[x] for x in select_AttrIdx]
    all_pd = all_pd[all_pd['AttrKey'].isin(select_AttrKey)]

    data_set = {}
    data_set['train'] = FSdata_aux(root_path=rawdata_root,
                                   anno_pd=all_pd,
                                   transforms=FSAug(),
                                   select=select_AttrIdx,
                                   # Label strings whose auxiliary target stays 0.
                                   auxignore={3: ['nnnyn'], 4: ['nny' + 'n' * 7]})

    data_loader = {}
    data_loader['train'] = data.DataLoader(data_set['train'], 3, num_workers=1,
                                           shuffle=False, pin_memory=True,
                                           collate_fn=collate_fn_aux)

    # print(x) with a single argument behaves the same on Python 2 and 3;
    # the former `print x` statements made this module unparseable on Python 3.
    for i, batch in enumerate(data_loader['train']):
        print(i)
        print(batch)
        break